import abc
import numpy as onp
import torch
import time
import os
import argparse
import json
from lib.data.pems import PeMS04, PeMS08
from typing import List, cast
from lib.exec.model import encoderize, TIMESTAMPED, NOTEMBEDS, PRETAINABLES
from lib.exec.utils.floatrep import floatrep
import os
from lib.task.trafficx import TrafficCross
from lib.framework.dyngraph.dyngraph import FrameworkDynamicGraph
from lib.framework.dyngraph.sdgnn import FrameworkSdgnn
from lib.task.regression import RMSE, MAPE
from lib.framework.types import TIMECOST
from lib.framework.transfer import transfer
import more_itertools as xitertools
from lib.model.gnnx2 import graphicalize, GNNx2Concat
from lib.model.snn import sequentialize
from lib.model.gnnx2 import graphicalize
import scipy.sparse as sp
# #
# META = TypeVar("META", bound=Meta)


#
# Module-level constants. They appear unused by the rest of this script.
DIR: str = "log"
LR_DECAY: int = 10
# 1e-4 scaled up by the factor (1 + 1 / LR_DECAY).
LR_THRES: float = 1e-4 * (1 + 1 / LR_DECAY)
IMP_ABS: int = 0
IMP_REL: int = 0



def identifier(**KWARGS) -> str:
    R"""
    Get identifier of given arguments.

    The result is a single line made of four "_"-joined sections whose
    fields are "~"-joined (shown wrapped here):

        <source>[~<train_prop>]~<target>~<trans|induc>
        _<model>[-preseq]~<hidden>~<activate>
        _<lr>~<weight_decay>~<clipper>~<patience>
        _<seed>

    "~<train_prop>" appears only when "train_prop" is given. "-preseq" is
    appended to the model when a non-empty "pretrain_seq_node" or
    "pretrain_seq_edge" is given. "framework" must be "transductive" or
    "inductive" (anything else raises KeyError). Every field is formatted
    with "{:s}", so all values must already be strings.
    """
    fields = KWARGS

    # Section 1: data source (optionally with training proportion), target
    # and the abbreviated framework name.
    if "train_prop" in fields:
        data_tag = "{:s}~{:s}".format(
            cast(str, fields["source"]), cast(str, fields["train_prop"]),
        )
    else:
        data_tag = cast(str, fields["source"])
    target_tag = cast(str, fields["target"])
    frame_tag = {"transductive": "trans", "inductive": "induc"}[
        cast(str, fields["framework"])
    ]

    # Section 2 opens with the model name, tagged when a pretrained sequence
    # checkpoint is supplied.
    has_preseq = any(
        key in fields and len(fields[key]) > 0
        for key in ("pretrain_seq_node", "pretrain_seq_edge")
    )
    model_tag = cast(str, fields["model"])
    if has_preseq:
        model_tag = "{:s}-preseq".format(model_tag)

    # Remaining fields, in output order.
    tail = [
        cast(str, fields[key])
        for key in (
            "hidden", "activate", "lr", "weight_decay",
            "clipper", "patience", "seed",
        )
    ]
    return "{:s}~{:s}~{:s}_{:s}~{:s}~{:s}_{:s}~{:s}~{:s}~{:s}_{:s}".format(
        data_tag, target_tag, frame_tag, model_tag, *tail,
    )


# NOTE(review): `parser` is created but never used; the run configuration is
# reloaded from a saved JSON argument dump instead of the command line.
parser = argparse.ArgumentParser()

# Directory holding the saved run arguments and the pretrained checkpoints.
log_base_dir = './log/pem08_target/'
# Saved arguments of the run to continue. The path is derived from
# `log_base_dir` so that relocating the run only needs one change.
json_dir = os.path.join(
    log_base_dir,
    'args_PeMS08~all~induc_GRUoGCN2x2~16~softplus_0.001~1.0e-5~value~30_56.json',
)

# JSON text is UTF-8. Decode it explicitly rather than with the platform's
# locale encoding.
with open(json_dir, 'rt', encoding='utf-8') as f:
    d = json.load(f)

args = argparse.Namespace(**d)


# Localize arguments: give every saved run setting its own module-level name.
# Dataset and prediction target.
source, target = args.source, args.target
# Model choice and shape.
neuralname, hidden, activate = args.model, args.hidden, args.activate
pretrain_seq_node, pretrain_seq_edge = (
    args.pretrain_seq_node, args.pretrain_seq_edge,
)
# Training settings.
lr, train_prop = args.lr, args.train_prop
weight_decay, clipper, patience = args.weight_decay, args.clipper, args.patience
seed = args.seed
# Execution settings.
device, resume_eval, frame = args.device, args.resume_eval, args.framework


# Always run on the first CUDA device, whatever device the saved arguments
# name, and keep `args` in sync with that choice.
device = 'cuda:0'
args.device = device


# Pretrained checkpoint files are located in the run's log directory. The two
# sequence-encoder paths replace the values previously taken from `args`.
(
    pretrain_seq_node,
    pretrain_seq_edge,
    pretrain_gnn,
    pretrain_low_level_mlp,
) = (
    os.path.join(log_base_dir, filename)
    for filename in (
        'snn_node.pth', 'snn_edge.pth', 'gnnx2.pth', 'lower_level_mlp.pth',
    )
)



# Constant and derived arguments (batch size, training-proportion tuple,
# dataset/model translation and the run description) are set up once, in the
# section further below.



# Number of graphs per batch used to build the model and framework below. It
# is also recorded on `args`. The run settings were already localized from
# `args` above, so they are not re-read here.
num_batch_graphs = 32
args.num_batch_graphs = num_batch_graphs
# Decode the training-proportion spec "<num>d<den>" into
# (abs(num), den, num < 0). An "n" in the numerator is read as a minus sign.
# An empty spec decodes to (0, 0, False).
train_prop_num, train_prop_den, train_prop_neg = 0, 0, False
if len(train_prop) > 0:
    (train_prop_num_str, train_prop_den_str) = train_prop.split("d")
    train_prop_num_str = train_prop_num_str.replace("n", "-")
    train_prop_num = int(train_prop_num_str)
    train_prop_den = int(train_prop_den_str)
    train_prop_neg = train_prop_num < 0
    train_prop_num = abs(train_prop_num)
train_prop_tuple = (train_prop_num, train_prop_den, train_prop_neg)



# Translate arguments: dataset class by name, the framework's `metaspindle`
# value by framework kind, and the target feature indices
# (flow=0, occupy=1, speed=2; "all" selects every one).
datasetize = dict(PeMS04=PeMS04, PeMS08=PeMS08)[source]
spindle = dict(transductive="node", inductive="time")[frame]
targeton = (
    [0, 1, 2] if target == "all"
    else [{"flow": 0, "occupy": 1, "speed": 2}[target]]
)

# Declared here; assigned after the model name is parsed.
attach_edge_time: List[str]
attach_node_time: List[str]

# Get neural network basement: the model name is "<base>" or "<base>-Seqed";
# any other suffix is rejected.
neuralitems = neuralname.split("-")
if len(neuralitems) > 1:
    #
    if len(neuralitems) > 2 or neuralitems[1] != "Seqed":
        # UNEXPECT:
        # Improper neural network suffix. Report the suffix first and the
        # base name second, matching the message wording.
        raise NotImplementedError(
            "Improper neural network suffix \"{:s}\" for \"{:s}\"."
            .format("-".join(neuralitems[1:]), neuralitems[0]),
        )
    neuralbase = neuralitems[0]
else:
    #
    neuralbase = neuralname

# Timestamp settings registered for the base model in TIMESTAMPED; models
# without an entry get none.
if neuralbase in TIMESTAMPED:
    #
    extend_edge_time = TIMESTAMPED[neuralbase]["edge"]
    extend_node_time = TIMESTAMPED[neuralbase]["node"]
else:
    #
    extend_edge_time = []
    extend_node_time = []
# No timestamp features are requested as extra inputs.
attach_edge_time: List[str] = []
attach_node_time: List[str] = []
# Extra keyword arguments for the run description: the training proportion is
# only included when one was given.
desckws = {"train_prop": train_prop} if len(train_prop) > 0 else {}



# Run description handed to the framework. It is computed once; the former
# identical `id` copy was unused and shadowed the `id` builtin. The model name
# gets a "-Seqed" suffix when the model is pretrainable and a sequence-encoder
# checkpoint path is set.
desc = identifier(
    source=source, target=target, framework=frame,
    model=(
        "{:s}-Seqed".format(neuralname)
        if (
            neuralname in PRETAINABLES
            and (len(pretrain_seq_node) > 0 or len(pretrain_seq_edge) > 0)
        ) else
        neuralname
    ),
    hidden=str(hidden), activate=activate, lr=floatrep(lr),
    weight_decay=floatrep(weight_decay), clipper=clipper,
    patience=str(patience), seed=str(seed),
    **desckws,
)


# Prepare PeMS dataset from src/<source> with aug_minutes and aug_weekdays on.
print("{0} {1} {0}".format("=" * 10, "Data & Meta"))
dataset = datasetize(
    os.path.join("src", source),
    aug_minutes=True,
    aug_weekdays=True,
)

# Formalize as a future prediction task: predict 1 future frame from 1 hour
# of history (60 // 5 = 12 historical frames).
metaset = dataset.asto_dynamic_adjacency_list_static_edge(
    window_history_size=60 // 5,
    window_future_size=1,
    timestamped_edge_times=extend_edge_time,
    timestamped_node_times=extend_node_time,
    timestamped_edge_feats=attach_edge_time,
    timestamped_node_feats=attach_node_time,
)
# Select inputs and targets over four positional slots. The slot meanings are
# defined by the metaset class, which is not visible here.
metaset.inputon(["all", "none", "all", "none"])
metaset.targeton(["none", "none", targeton, "none"])
print(metaset)



# Prepare PeMS model: an encoder for the base model, wrapped in TrafficCross
# with one output per target feature, then seeded initialization.
print("{0} {1} {0}".format("=" * 10, "Model & Task"))
neuralnet = TrafficCross(
    encoderize(
        neuralbase,
        metaset.edge_feat_size,
        metaset.node_feat_size,
        hidden,
        hidden,
        activate,
        dyn_edge=False,
        dyn_node=True,
        # Node count covers every graph in a batch.
        num_nodes=metaset.num_nodes * num_batch_graphs,
        tid_ax=metaset.node_feat_size,
    ),
    len(targeton),
    hidden,
    activate=activate,
    notembedon=(targeton if neuralbase in NOTEMBEDS else []),
)
neuralnet.initialize(seed)



# For pretrainable models, hand the checkpoint paths chosen above to the
# model: the node sequence model and the GNN inside `neuralnet.tgnn`, then the
# lower-level MLP on `neuralnet` itself. The method names suggest training
# continues from these pretrained weights; this is not verifiable here.
if neuralname in PRETAINABLES:
    #
    neuralnet.tgnn.continue_train_with_pretrain_node_model(pretrain_seq_node)
    neuralnet.tgnn.continue_train_with_pretrain_gnn_model(pretrain_gnn)
    neuralnet.continue_train_with_pretrain_mlp(pretrain_low_level_mlp)
    # NOTE(review): the edge sequence checkpoint (pretrain_seq_edge) is never
    # loaded. The former call `neuralnet.tgnn.pretrain("edge",
    # pretrain_seq_edge)` is disabled.

print(neuralnet)


# Prepare framework around the description, metaset and model. Node batching
# is switched off.
framework = FrameworkSdgnn(
    desc,
    metaset,
    neuralnet,
    lr=lr,
    weight_decay=weight_decay,
    seed=seed,
    device=device,
    metaspindle=spindle,
    gradclip=clipper,
    doc_num=9,
)
framework.set_node_batching(False)



# Run schedule: fit_sdgnn, then fit_low_level_mlp, then besteval_sdgnn
# (presumably evaluation of the best saved model; confirm in the framework).
# Every phase receives the same positional arguments, whose meaning is defined
# by FrameworkSdgnn. Every phase also uses the same batch size and validates
# on MAPE.
phase_args = ((7, 1, 2), (2, 1, 0), train_prop_tuple)

framework.fit_sdgnn(
    *phase_args,
    batch_size=num_batch_graphs,
    max_epochs=4000,
    validon=MAPE,
    validrep="MAPE",
    patience=patience,
)

framework.fit_low_level_mlp(
    *phase_args,
    batch_size=num_batch_graphs,
    max_epochs=100,
    validon=MAPE,
    validrep="MAPE",
    patience=patience,
)

framework.besteval_sdgnn(
    *phase_args,
    batch_size=num_batch_graphs,
    validon=MAPE,
    validrep="MAPE",
    resume=resume_eval,
)